import torch

from ..smp import *
from .base import BaseModel
from ..dataset import DATASET_TYPE
from transformers import AutoModel, GenerationConfig


class WeMM(BaseModel):
    def __init__(self, model_path="feipengma/WeMM", **kwargs):
        # Load the WeMM checkpoint in bfloat16; trust_remote_code is required
        # because the model's forward/generation code ships with the weights.
        self.wemm = AutoModel.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
        )
        self.wemm.cuda()
        self.wemm.eval()
        torch.cuda.empty_cache()

    def use_custom_prompt(self, dataset):
        # Only multiple-choice datasets use the custom prompt built below.
        assert dataset is not None
        return DATASET_TYPE(dataset) == "MCQ"

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        # Prepend the hint (if any) to the question, then append the candidate
        # options as "A. ...", "B. ...", etc.
        question = line["question"]
        hint = line["hint"] if ("hint" in line and not pd.isna(line["hint"])) else None
        if hint is not None:
            question = hint + "\n" + question
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f"\n{key}. {item}"
        prompt = question

        # Ask for a terse answer in the language of the prompt; the Chinese
        # strings mean "Please answer with the option's letter directly." and
        # "Please answer the question directly."
        if options:
            prompt += (
                "\n请直接回答选项字母。"
                if cn_string(prompt)
                else "\nAnswer with the option's letter from the given choices directly."
            )
        else:
            prompt += (
                "\n请直接回答问题。"
                if cn_string(prompt)
                else "\nAnswer the question directly."
            )

        message = [dict(type="text", value=prompt)]
        message.extend([dict(type="image", value=p) for p in tgt_path])
        return message
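
    # A sketch of the message build_prompt returns (values are hypothetical;
    # real image paths come from self.dump_image):
    #
    # [
    #     {"type": "text", "value": "Which animal is shown?\nA. cat\nB. dog\n"
    #                               "Answer with the option's letter from the given choices directly."},
    #     {"type": "image", "value": "/tmp/dump/123.jpg"},
    # ]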

    def generate_inner(self, message, dataset=None):
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)

        # HallusionBench expects yes/no judgments, so request a terse answer.
        if dataset == "HallusionBench":
            prompt = (
                prompt
                + " Please answer yes or no. Answer the question using a single word or phrase."
            )

        # MMVet answers are open-ended, so sample with beam search and a larger
        # token budget; other datasets use the model's default generation config.
        gen_config = None
        if dataset == "MMVet":
            gen_config = GenerationConfig(
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                num_beams=3,
                eos_token_id=self.wemm.tokenizer.eos_token_id,
                # Fall back to EOS when the tokenizer defines no pad token.
                pad_token_id=(
                    self.wemm.tokenizer.pad_token_id
                    if self.wemm.tokenizer.pad_token_id is not None
                    else self.wemm.tokenizer.eos_token_id
                ),
            )
        pred = self.wemm.mm_generate(image_path, prompt, gen_config)

        return pred
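
# Minimal smoke test (a sketch, not part of the evaluation pipeline): assumes a
# CUDA GPU, access to the feipengma/WeMM weights, and a local image file whose
# path below is hypothetical. Because of the relative imports above, run this
# as a module (python -m ...) rather than as a standalone script.
if __name__ == "__main__":
    model = WeMM()
    message = [
        dict(type="text", value="Describe the image in one sentence."),
        dict(type="image", value="demo.jpg"),  # hypothetical image path
    ]
    print(model.generate_inner(message))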
